#!/usr/bin/env python3

"""
Utility to put an AWS EventBridge rule.

This is not really a general purpose utility as it reads its source data from a
specially formatted JSON structure provided by the lava job framework.

Note that this obliterates any existing rule with the same name (rule ID) and
also replaces all targets with those specified in the rule specification file.

"""

from __future__ import annotations

import argparse
import json
import os
import re
import sys
from collections.abc import Iterable
from contextlib import suppress
from fnmatch import fnmatchcase
from typing import Any
from uuid import uuid4

import boto3
from botocore.exceptions import ClientError

__author__ = 'Murray Andrews'
__version__ = '1.0.1'

# Program name used in messages: the script's base file name without extension.
PROG = os.path.splitext(os.path.basename(sys.argv[0]))[0]

# Keys every rule specification must contain. Keys matching x-* are ignored by
# the dict_check() call in do_spec_file() and are turned into rule tags there.
RULE_SPEC_REQUIRED = {'rule_id', 'owner', 'description'}
# Keys a rule specification may contain in addition to the required ones.
RULE_SPEC_OPTIONAL = {
    'enabled',
    'event_bus_name',
    'event_pattern',
    'role_arn',
    'schedule_expression',
    'tags',
    'targets',
}

# Matches any single character NOT in the allowed set; every such character in
# a tag key or value is replaced with TAG_SAFE_CHAR before the rule is created.
TAG_BAD_CHARS = re.compile(r'[^\w .:+=@_/-]')
TAG_SAFE_CHAR = '_'

# ..............................................................................
# region utilities
# ..............................................................................


# ------------------------------------------------------------------------------
def dict_strip(d: dict) -> dict:
    """
    Build a copy of a dictionary that omits every entry whose value is None.

    Only None is dropped; other falsy values (0, '', empty containers) are
    kept. Key order of the surviving entries is preserved.

    :param d:       Source dictionary. It is not modified.
    :return:        A new dict holding only the non-None entries of d.

    """

    stripped = {}
    for key, value in d.items():
        if value is None:
            continue
        stripped[key] = value
    return stripped


# ------------------------------------------------------------------------------
def glob_strip(names: Iterable[str], patterns: str | Iterable[str]) -> set[str]:
    """
    Remove from an iterable of strings any that match any of the given patterns.

    Pattens are glob style.

    The result is returned as a set so any ordering is lost.

    :param names:       An iterable of strings to match.
    :param patterns:    A glob pattern or iterable of glob patterns.

    :return:            A set containing all input strings that don't match any
                        of the glob patterns.
    """

    names = set(names)
    if isinstance(patterns, str):
        patterns = [str]

    for p in patterns:
        # noinspection PyTypeChecker
        names -= {n for n in names if fnmatchcase(n, p)}

    return names


# ------------------------------------------------------------------------------
def dict_check(
    d: dict[str, Any],
    required: Iterable[str] = None,
    optional: Iterable[str] = None,
    ignore: str | Iterable[str] = None,
) -> None:
    """
    Check that the given dictionary has the required keys.

    :param d:           The dict to check.
    :param required:    An iterable of mandatory keys. Can be None indicating
                        required keys should not be checked.
    :param optional:    An iterable of optional keys. Can be None indicating
                        optional keys should not be checked.
    :param ignore:      Ignore any keys that match the specified glob pattern
                        or list of patterns.

    :raise ValueError:  If the dict doesn't contain all required keys or does
                        contain disallowed keys.
    """

    if required is not None and not isinstance(required, set):
        # noinspection PyTypeChecker
        required = set(required)
    if optional is not None and not isinstance(optional, set):
        optional = set(optional)

    actual_keys = set(d)

    # Remove the ignore keys from everything required and actual keys.
    # No need to remove from optionals.
    if ignore:
        if isinstance(ignore, str):
            ignore = [ignore]
        if required:
            required = glob_strip(required, ignore)
        actual_keys = glob_strip(actual_keys, ignore)

    if required is not None and not required <= actual_keys:
        raise ValueError('Missing keys: {}'.format(', '.join(sorted(required - actual_keys))))

    if optional is not None:
        bad_keys = actual_keys - (required if required is not None else set()) - optional
        if bad_keys:
            raise ValueError('Unexpected keys: {}'.format(', '.join(sorted(bad_keys))))


# ..............................................................................
# endregion utilities
# ..............................................................................


# ------------------------------------------------------------------------------
class ResourceNotFoundError(Exception):
    """Raised when a requested AWS resource does not exist."""


# ------------------------------------------------------------------------------
def process_cli_args() -> argparse.Namespace:
    """
    Parse the command line.

    :return:    Namespace with ``profile`` (None unless --profile is given) and
                ``rule_specs`` (list of specification file names, possibly
                empty).
    """

    parser = argparse.ArgumentParser(
        prog=PROG,
        description='Create an AWS EventBridge rule from a specification file.',
    )
    parser.add_argument('--profile', action='store', help='As for AWS CLI.')
    parser.add_argument('-v', '--version', action='version', version=__version__)
    parser.add_argument(
        'rule_specs',
        nargs='*',
        metavar='rule-spec.json',
        help='A JSON formatted event rule specification.',
    )

    parsed = parser.parse_args()
    return parsed


# ------------------------------------------------------------------------------
def list_targets_by_rule(
    rule_name: str, events_client, event_bus_name: str = 'default'
) -> list[dict[str, Any]]:
    """
    Fetch the targets attached to an EventBridge rule.

    Only a single list_targets_by_rule response is read; any NextToken in the
    response is not followed.

    :param rule_name:       The rule name.
    :param events_client:   A boto3 events client.
    :param event_bus_name:  Event bus name. Defaults to 'default'.
    :return:                The ``Targets`` list from the response (see boto3
                            doco).

    :raises ResourceNotFoundError: If the rule doesn't exist.
    """

    try:
        response = events_client.list_targets_by_rule(
            Rule=rule_name, EventBusName=event_bus_name
        )
    except ClientError as err:
        # Translate only the "no such rule" error; anything else propagates.
        if err.response['Error']['Code'] != 'ResourceNotFoundException':
            raise
        raise ResourceNotFoundError(f'RuleName={rule_name} EventBusName={event_bus_name}')

    return response['Targets']


# ------------------------------------------------------------------------------
def put_targets_for_rule(
    rule_name: str, targets: list, events_client, event_bus_name: str = 'default'
) -> list[dict[str, str]]:
    """
    Attach targets to an EventBridge rule.

    Each entry in ``targets`` is either:

    *   A string ARN, which is wrapped as ``{'Id': 'Id<uuid4>', 'Arn': ARN}``.
    *   Any other value, passed unchanged to boto3 events put_targets, so it
        must already match that API's target structure (camel case and all).

    :param rule_name:       The rule name.
    :param targets:         A list of target entries as above.
    :param events_client:   A boto3 events client.
    :param event_bus_name:  Event bus name. Defaults to 'default'.
    :return:                The FailedEntries list from boto3 put_targets(),
                            or an empty list if none were reported.

    :raise ValueError:      If a string target does not start with ``arn:``.
                            Raised before any call to AWS is made.
    """

    expanded_targets = []
    for target in targets:
        if isinstance(target, str):
            if not target.startswith('arn:'):
                raise ValueError(f'Bad target: {target}')
            target = {'Id': f'Id{uuid4()}', 'Arn': target}
        expanded_targets.append(target)

    response = events_client.put_targets(
        Rule=rule_name, EventBusName=event_bus_name, Targets=expanded_targets
    )
    return response.get('FailedEntries', [])


# ------------------------------------------------------------------------------
def _build_put_rule_args(rule_spec: dict[str, Any], event_bus_name: str) -> dict[str, Any]:
    """
    Construct the keyword arguments for the boto3 events put_rule() call.

    Tags are assembled from the spec's ``tags`` mapping, its ``owner`` and any
    ``x-*`` keys. Non-string tag values are rendered as JSON. Characters that
    match TAG_BAD_CHARS are replaced with TAG_SAFE_CHAR. None valued arguments
    are dropped so that boto3 omits them.

    :param rule_spec:       A validated rule specification.
    :param event_bus_name:  Event bus name.
    :return:                Keyword arguments for put_rule().
    """

    tags = [{'Key': k, 'Value': v} for k, v in (rule_spec.get('tags') or {}).items()]

    # Add the lava specials into the tags
    tags.append({'Key': 'owner', 'Value': rule_spec['owner']})
    tags.extend(
        {'Key': k, 'Value': v} for k, v in rule_spec.items() if k.lower().startswith('x-')
    )

    # Clean up tag keys and values. JSON can supply numbers, booleans etc. as
    # values, which re.sub() rejects, so render those as JSON text first.
    for tag in tags:
        value = tag['Value']
        if not isinstance(value, str):
            value = json.dumps(value)
        tag['Key'] = TAG_BAD_CHARS.sub(TAG_SAFE_CHAR, tag['Key'])
        tag['Value'] = TAG_BAD_CHARS.sub(TAG_SAFE_CHAR, value)

    put_rule_args = {
        'Name': rule_spec['rule_id'],
        'Description': rule_spec['description'],
        'ScheduleExpression': rule_spec.get('schedule_expression'),
        'State': 'ENABLED' if rule_spec.get('enabled', False) else 'DISABLED',
        'RoleArn': rule_spec.get('role_arn'),
        'Tags': tags,
        'EventBusName': event_bus_name,
    }

    # An explicit null pattern would be serialised as 'null'; treat it as absent.
    if rule_spec.get('event_pattern') is not None:
        put_rule_args['EventPattern'] = json.dumps(rule_spec['event_pattern'])

    return dict_strip(put_rule_args)


# ------------------------------------------------------------------------------
def _remove_existing_rule(rule_name: str, event_bus_name: str, events_client) -> None:
    """
    Delete a rule, and all of its targets first, if the rule exists.

    :param rule_name:       The rule name.
    :param event_bus_name:  Event bus name.
    :param events_client:   A boto3 events client.

    :raise Exception:       If looking up the rule's targets fails for any
                            reason other than the rule not existing.
    """

    try:
        old_targets = list_targets_by_rule(
            rule_name, event_bus_name=event_bus_name, events_client=events_client
        )
    except ResourceNotFoundError:
        return  # No existing rule -- nothing to delete.
    except Exception as e:
        raise Exception(f'Rule {rule_name}: {e}') from e

    # The rule exists so we need to delete existing targets and then the rule
    if old_targets:
        events_client.remove_targets(
            Rule=rule_name,
            EventBusName=event_bus_name,
            Ids=[t['Id'] for t in old_targets],
            Force=True,
        )
        print(f'Rule {rule_name}: Deleted {len(old_targets)} old target(s)')
    events_client.delete_rule(Name=rule_name, EventBusName=event_bus_name, Force=True)
    print(f'Rule {rule_name}: Deleted old rule')


# ------------------------------------------------------------------------------
def do_spec_file(filename: str, aws_session: boto3.Session = None) -> None:
    """
    Process a rule specification file to create an EventBridge rule.

    Any existing rule with the same name on the same event bus is deleted
    (targets first). The rule is then recreated from the specification, and
    the specification's targets are attached.

    :param filename:        Name of JSON formatted specification file.
    :param aws_session:     Boto3 Session. A default session is created if
                            not supplied.

    :raise Exception:       If the specification is invalid, the existing rule
                            cannot be inspected, or any target fails to be
                            created.
    """

    if not aws_session:
        aws_session = boto3.Session()

    events_client = aws_session.client('events')
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(filename, encoding='utf-8') as fp:
        rule_spec = json.load(fp)

    if not rule_spec:
        print(f'{filename}: Skipping empty specification file')
        return

    try:
        dict_check(
            rule_spec, required=RULE_SPEC_REQUIRED, optional=RULE_SPEC_OPTIONAL, ignore='x-*'
        )
    except Exception as e:
        raise Exception(f'{filename}: {e}') from e

    rule_name = rule_spec['rule_id']
    event_bus_name = rule_spec.get('event_bus_name') or 'default'

    # Build the arguments BEFORE touching AWS so that a spec that fails here
    # does not leave the old rule deleted with nothing to replace it.
    put_rule_args = _build_put_rule_args(rule_spec, event_bus_name)

    _remove_existing_rule(rule_name, event_bus_name, events_client)

    # Create rule (no targets at this point)
    response = events_client.put_rule(**put_rule_args)
    print(f'Rule {rule_name}: Created new rule {response["RuleArn"]}')

    # Create the targets
    new_targets = rule_spec.get('targets')
    if not new_targets:
        return

    failed_targets = put_targets_for_rule(
        rule_name,
        targets=new_targets,
        event_bus_name=event_bus_name,
        events_client=events_client,
    )

    if failed_targets:
        # Report every failed target before aborting, not just the first.
        for ft in failed_targets:
            print(
                'Rule {rule_name}: Bad target: {TargetId}: {ErrorCode}: {ErrorMessage}'.format(
                    rule_name=rule_name, **ft
                )
            )
        raise Exception('Bad target(s)')

    print(f'Rule {rule_name}: Created {len(new_targets)} new target(s)')


# ------------------------------------------------------------------------------
def main() -> int:
    """
    Create or replace an EventBridge rule for each spec file on the command line.

    :return:        Exit status. Always 0; failures are raised as exceptions
                    whose message is prefixed with the offending file name.
    """

    args = process_cli_args()
    session = boto3.Session(profile_name=args.profile)

    current = None
    try:
        for current in args.rule_specs:
            do_spec_file(current, aws_session=session)
    except Exception as err:
        # Processing stops at the first failing file.
        raise Exception(f'{current}: {err}')

    return 0


# ------------------------------------------------------------------------------
if __name__ == '__main__':
    # Use sys.exit(): the builtin exit() is a site-module convenience for the
    # interactive shell and is not available when Python runs with -S.
    # Uncomment for debugging
    # sys.exit(main())  # noqa: ERA001
    try:
        sys.exit(main())
    except Exception as ex:
        # SystemExit is not an Exception subclass, so a normal exit passes through.
        print(f'{PROG}: {ex}', file=sys.stderr)
        sys.exit(1)
